Find leaves of binary tree

Time: O(N); Space: O(H), where N is the number of nodes and H is the height of the tree; difficulty: medium

Given the root of a binary tree, collect the tree’s nodes as if you were doing the following:

  • Collect and remove all leaves,

  • repeat until the tree is empty.

Example 1:

    1
   / \
  2   3
 / \
4   5

Input: root = {TreeNode} [1,2,3,4,5]

Output: [[4, 5, 3], [2], [1]]

Example 2:

    1
   / \
  2   3
 /
4

Input: root = {TreeNode} [1,2,3,4]

Output: [[4, 3], [2], [1]]

[1]:
class TreeNode(object):
    """A binary tree node: a value plus optional left and right children."""

    def __init__(self, x):
        # Children start empty; callers link nodes together after construction.
        self.val, self.left, self.right = x, None, None
[2]:
class Solution1(object):
    """
    Group node values by height: leaves have height 0, and every other node
    sits one level above its tallest child. The nodes of height k are exactly
    the ones removed in round k of "collect and remove all leaves".

    Time: O(N)
    Space: O(H) auxiliary (explicit stack), plus O(N) for the output
    """
    def findLeaves(self, root):
        """
        :type root: TreeNode
        :rtype: List[List[int]]

        Returns one list per removal round. Within a round, values appear in
        post-order (left subtree, right subtree, node). An empty tree yields [].

        The traversal uses an explicit stack instead of recursion, so very deep
        (e.g. fully skewed) trees do not hit Python's recursion limit.
        """
        result = []
        if not root:
            return result

        # Frame layout: [node, stage, left_height]
        #   stage 0 -> descend into the left child
        #   stage 1 -> store the left height, then descend into the right child
        #   stage 2 -> both child heights are known; emit the node
        stack = [[root, 0, -1]]
        # Height reported by the most recently finished subtree; -1 stands for
        # a missing child, so a leaf gets 1 + max(-1, -1) == 0.
        child_height = -1

        while stack:
            frame = stack[-1]
            node, stage = frame[0], frame[1]
            if stage == 0:
                frame[1] = 1
                if node.left:
                    stack.append([node.left, 0, -1])
                else:
                    child_height = -1
            elif stage == 1:
                frame[2] = child_height
                frame[1] = 2
                if node.right:
                    stack.append([node.right, 0, -1])
                else:
                    child_height = -1
            else:
                stack.pop()
                level = 1 + max(frame[2], child_height)
                # Post-order means groups 0..level-1 already exist, so at most
                # one new group is needed here.
                if len(result) < level + 1:
                    result.append([])
                result[level].append(node.val)
                child_height = level

        return result
[3]:
s = Solution1()

# Example 1: 1 has children 2 and 3; 2 has children 4 and 5.
root = TreeNode(1)
root.left, root.right = TreeNode(2), TreeNode(3)
root.left.left, root.left.right = TreeNode(4), TreeNode(5)
assert s.findLeaves(root) == [[4, 5, 3], [2], [1]]

# Example 2: the same tree without node 5.
root = TreeNode(1)
root.left, root.right = TreeNode(2), TreeNode(3)
root.left.left = TreeNode(4)
assert s.findLeaves(root) == [[4, 3], [2], [1]]

# Edge cases: an empty tree has no removal rounds; a lone root is itself a
# leaf, so it is removed in the first round.
assert s.findLeaves(None) == []
assert s.findLeaves(TreeNode(7)) == [[7]]